{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "oUSlPh7xdYwV"
      },
      "outputs": [],
      "source": [
        "# Copyright 2021 Google Inc.\n",
        "\n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "\n",
        "#     http://www.apache.org/licenses/LICENSE-2.0\n",
        "\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "9n1hzmoE9onA"
      },
      "source": [
        "# This code supports the publication \"Using Deep Learning to Annotate the Protein Universe\".\n",
        "[preprint link](https://doi.org/10.1101/626507)\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QX30QL0bIiPs"
      },
      "source": [
        "**Note**: We recommend you enable a free GPU by going:\n",
        "\n",
        "\u003e **Runtime**   →   **Change runtime type**   →   **Hardware Accelerator: GPU**\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EzXGgAYDN2Os"
      },
      "source": [
        "# Set-up"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "v-6Ufto98_Ro"
      },
      "source": [
        "## Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "P6_SeJumHUGK"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "import numpy as np\n",
        "import tensorflow.compat.v1 as tf\n",
        "import tqdm\n",
        "\n",
        "# Suppress noisy log messages.\n",
        "from tensorflow.python.util import deprecation\n",
        "deprecation._PRINT_DEPRECATION_WARNINGS = False"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Pca8JUJH9GyK"
      },
      "source": [
        "## Library functions: convert sequence to one-hot array (input to model)\n",
        "\n",
        "The following library functions are copied from the github repo so as to make this colab dependency-free: no installation of packages is required - just the standard colab kernel."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Lp-3Ccx09IJ7"
      },
      "outputs": [],
      "source": [
        "AMINO_ACID_VOCABULARY = [\n",
        "    'A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', 'R',\n",
        "    'S', 'T', 'V', 'W', 'Y'\n",
        "]\n",
        "def residues_to_one_hot(amino_acid_residues):\n",
        "  \"\"\"Given a sequence of amino acids, return one hot array.\n",
        "\n",
        "  Supports ambiguous amino acid characters B, Z, and X by distributing evenly\n",
        "  over possible values, e.g. an 'X' gets mapped to [.05, .05, ... , .05].\n",
        "\n",
        "  Supports rare amino acids by appropriately substituting. See\n",
        "  normalize_sequence_to_blosum_characters for more information.\n",
        "\n",
        "  Supports gaps and pads with the '.' and '-' characters; which are mapped to\n",
        "  the zero vector.\n",
        "\n",
        "  Args:\n",
        "    amino_acid_residues: string. consisting of characters from\n",
        "      AMINO_ACID_VOCABULARY\n",
        "\n",
        "  Returns:\n",
        "    A numpy array of shape (len(amino_acid_residues),\n",
        "     len(AMINO_ACID_VOCABULARY)).\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if sparse_amino_acid has a character not in the vocabulary + X.\n",
        "  \"\"\"\n",
        "  to_return = []\n",
        "  normalized_residues = amino_acid_residues.replace('U', 'C').replace('O', 'X')\n",
        "  for char in normalized_residues:\n",
        "    if char in AMINO_ACID_VOCABULARY:\n",
        "      to_append = np.zeros(len(AMINO_ACID_VOCABULARY))\n",
        "      to_append[AMINO_ACID_VOCABULARY.index(char)] = 1.\n",
        "      to_return.append(to_append)\n",
        "    elif char == 'B':  # Asparagine or aspartic acid.\n",
        "      to_append = np.zeros(len(AMINO_ACID_VOCABULARY))\n",
        "      to_append[AMINO_ACID_VOCABULARY.index('D')] = .5\n",
        "      to_append[AMINO_ACID_VOCABULARY.index('N')] = .5\n",
        "      to_return.append(to_append)\n",
        "    elif char == 'Z':  # Glutamine or glutamic acid.\n",
        "      to_append = np.zeros(len(AMINO_ACID_VOCABULARY))\n",
        "      to_append[AMINO_ACID_VOCABULARY.index('E')] = .5\n",
        "      to_append[AMINO_ACID_VOCABULARY.index('Q')] = .5\n",
        "      to_return.append(to_append)\n",
        "    elif char == 'X':\n",
        "      to_return.append(\n",
        "          np.full(len(AMINO_ACID_VOCABULARY), 1. / len(AMINO_ACID_VOCABULARY)))\n",
        "    elif char == _PFAM_GAP_CHARACTER:\n",
        "      to_return.append(np.zeros(len(AMINO_ACID_VOCABULARY)))\n",
        "    else:\n",
        "      raise ValueError('Could not one-hot code character {}'.format(char))\n",
        "  return np.array(to_return)\n",
        "\n",
        "def _test_residues_to_one_hot():\n",
        "  expected = np.zeros((3, 20))\n",
        "  expected[0, 0] = 1.   # Amino acid A\n",
        "  expected[1, 1] = 1.   # Amino acid C\n",
        "  expected[2, :] = .05  # Amino acid X\n",
        "\n",
        "  actual = residues_to_one_hot('ACX')\n",
        "  np.testing.assert_allclose(actual, expected)\n",
        "_test_residues_to_one_hot()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FxYneKVaHxLC"
      },
      "outputs": [],
      "source": [
        "def pad_one_hot_sequence(sequence: np.ndarray,\n",
        "                         target_length: int) -\u003e np.ndarray:\n",
        "  \"\"\"Pads one hot sequence [seq_len, num_aas] in the seq_len dimension.\"\"\"\n",
        "  sequence_length = sequence.shape[0]\n",
        "  pad_length = target_length - sequence_length\n",
        "  if pad_length \u003c 0:\n",
        "    raise ValueError(\n",
        "        'Cannot set a negative amount of padding. Sequence length was {}, target_length was {}.'\n",
        "        .format(sequence_length, target_length))\n",
        "  pad_values = [[0, pad_length], [0, 0]]\n",
        "  return np.pad(sequence, pad_values, mode='constant')\n",
        "\n",
        "def _test_pad_one_hot():\n",
        "  input_one_hot = residues_to_one_hot('ACX')\n",
        "  expected = np.array(input_one_hot.tolist() + np.zeros((4, 20)).tolist())\n",
        "  actual = pad_one_hot_sequence(input_one_hot, 7)\n",
        "\n",
        "  np.testing.assert_allclose(expected, actual)\n",
        "_test_pad_one_hot()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eYhpG2-K5CLx"
      },
      "outputs": [],
      "source": [
        "def batch_iterable(iterable, batch_size):\n",
        "  \"\"\"Yields batches from an iterable.\n",
        "\n",
        "  If the number of elements in the iterator is not a multiple of batch size,\n",
        "  the last batch will have fewer elements.\n",
        "\n",
        "  Args:\n",
        "    iterable: a potentially infinite iterable.\n",
        "    batch_size: the size of batches to return.\n",
        "\n",
        "  Yields:\n",
        "    array of length batch_size, containing elements, in order, from iterable.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: if batch_size \u003c 1.\n",
        "  \"\"\"\n",
        "  if batch_size \u003c 1:\n",
        "    raise ValueError(\n",
        "        'Cannot have a batch size of less than 1. Received: {}'.format(\n",
        "            batch_size))\n",
        "\n",
        "  current = []\n",
        "  for item in iterable:\n",
        "    if len(current) == batch_size:\n",
        "      yield current\n",
        "      current = []\n",
        "    current.append(item)\n",
        "\n",
        "  # Prevent yielding an empty batch. Instead, prefer to end the generation.\n",
        "  if current:\n",
        "    yield current\n",
        "\n",
        "def _test_batch_iterable():\n",
        "  itr = [1, 2, 3]\n",
        "  batched_itr = list(batch_iterable(itr, 2))\n",
        "  assert batched_itr == [[1, 2], [3]]\n",
        "\n",
        "_test_batch_iterable()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "egM-2RHJ9Bnm"
      },
      "source": [
        "## Download model and vocabulary"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "Up68255TDGKe",
        "outputId": "15e90ba7-6363-400b-8994-cc0b85b1e080"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "--2021-09-22 14:58:00--  https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/models/single_domain_per_sequence_zipped_models/trained_model_pfam_32.0_vocab.json\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 74.125.31.128, 74.125.141.128, 173.194.210.128, ...\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|74.125.31.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 197219 (193K) [application/octet-stream]\n",
            "Saving to: ‘trained_model_pfam_32.0_vocab.json’\n",
            "\n",
            "\r          trained_m   0%[                    ]       0  --.-KB/s               \rtrained_model_pfam_ 100%[===================\u003e] 192.60K  --.-KB/s    in 0.007s  \n",
            "\n",
            "2021-09-22 14:58:00 (28.0 MB/s) - ‘trained_model_pfam_32.0_vocab.json’ saved [197219/197219]\n",
            "\n"
          ]
        }
      ],
      "source": [
        "# Get a TensorFlow SavedModel\n",
        "!wget -qN https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/models/single_domain_per_sequence_zipped_models/seed_random_32.0/5356760.tar.gz\n",
        "# unzip\n",
        "!tar xzf 5356760.tar.gz\n",
        "# Get the vocabulary for the model, which tells you which output index means which family\n",
        "!wget https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/models/single_domain_per_sequence_zipped_models/trained_model_pfam_32.0_vocab.json"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "XM3H7Y-rGPg8",
        "outputId": "611e4519-2f21-4ac6-9202-30ec7ce5ee8c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "5356760.tar.gz\n",
            "\n",
            "trn-_cnn_random__random_sp_gpu-cnn_for_random_pfam-5356760:\n",
            "saved_model.pb\tvariables\n"
          ]
        }
      ],
      "source": [
        "# Find the unzipped path\n",
        "!ls *5356760*"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MYaOgeZx9JrV"
      },
      "source": [
        "## Load the model into TensorFlow"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Bp1imZTA6L_0"
      },
      "outputs": [],
      "source": [
        "sess = tf.Session()\n",
        "graph = tf.Graph()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "zbZfobaq6bkI",
        "outputId": "025a5050-8d79-4517-d43c-e550aa106c96"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "INFO:tensorflow:Restoring parameters from trn-_cnn_random__random_sp_gpu-cnn_for_random_pfam-5356760/variables/variables\n"
          ]
        }
      ],
      "source": [
        "with graph.as_default():\n",
        "  saved_model = tf.saved_model.load(sess, ['serve'], 'trn-_cnn_random__random_sp_gpu-cnn_for_random_pfam-5356760')"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "IR0Gl4Mx9Mgm"
      },
      "source": [
        "## Load tensors for class prediction"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zbapL8YS7FpE"
      },
      "outputs": [],
      "source": [
        "top_pick_signature = saved_model.signature_def['serving_default']\n",
        "top_pick_signature_tensor_name = top_pick_signature.outputs['output'].name\n",
        "\n",
        "sequence_input_tensor_name = saved_model.signature_def['confidences'].inputs['sequence'].name\n",
        "sequence_lengths_input_tensor_name = saved_model.signature_def['confidences'].inputs['sequence_length'].name"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SXLMiRbMAHEi"
      },
      "source": [
        "## Load mapping from neural network outputs to Pfam family names "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CSNCbb5aAEx_"
      },
      "outputs": [],
      "source": [
        "with open('trained_model_pfam_32.0_vocab.json') as f:\n",
        "  vocab = json.loads(f.read())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "An9BVMs_Dn6d"
      },
      "source": [
        "# Download data for inference"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wyOMVA5EDsP3",
        "outputId": "97282fc7-d2da-44b1-d461-ed3cb1a7f48c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "--2021-09-22 14:58:08--  https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/random_split/test/data-00000-of-00010\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.11.128, 74.125.26.128, 172.217.204.128, ...\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.11.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 6106511 (5.8M) [application/octet-stream]\n",
            "Saving to: ‘data-00000-of-00010’\n",
            "\n",
            "\rdata-00000-of-00010   0%[                    ]       0  --.-KB/s               \rdata-00000-of-00010 100%[===================\u003e]   5.82M  --.-KB/s    in 0.04s   \n",
            "\n",
            "2021-09-22 14:58:08 (160 MB/s) - ‘data-00000-of-00010’ saved [6106511/6106511]\n",
            "\n",
            "--2021-09-22 14:58:08--  https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/random_split/test/data-00001-of-00010\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.214.128, 173.194.216.128, 173.194.217.128, ...\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.214.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 6251203 (6.0M) [application/octet-stream]\n",
            "Saving to: ‘data-00001-of-00010’\n",
            "\n",
            "data-00001-of-00010 100%[===================\u003e]   5.96M  --.-KB/s    in 0.03s   \n",
            "\n",
            "2021-09-22 14:58:08 (173 MB/s) - ‘data-00001-of-00010’ saved [6251203/6251203]\n",
            "\n",
            "--2021-09-22 14:58:08--  https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/random_split/test/data-00002-of-00010\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.214.128, 173.194.215.128, 173.194.216.128, ...\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.214.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 6161739 (5.9M) [application/octet-stream]\n",
            "Saving to: ‘data-00002-of-00010’\n",
            "\n",
            "data-00002-of-00010 100%[===================\u003e]   5.88M  --.-KB/s    in 0.03s   \n",
            "\n",
            "2021-09-22 14:58:08 (173 MB/s) - ‘data-00002-of-00010’ saved [6161739/6161739]\n",
            "\n",
            "--2021-09-22 14:58:08--  https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/random_split/test/data-00003-of-00010\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.213.128, 173.194.214.128, 173.194.216.128, ...\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.213.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 6132686 (5.8M) [application/octet-stream]\n",
            "Saving to: ‘data-00003-of-00010’\n",
            "\n",
            "data-00003-of-00010 100%[===================\u003e]   5.85M  --.-KB/s    in 0.04s   \n",
            "\n",
            "2021-09-22 14:58:09 (133 MB/s) - ‘data-00003-of-00010’ saved [6132686/6132686]\n",
            "\n",
            "--2021-09-22 14:58:09--  https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/random_split/test/data-00004-of-00010\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 108.177.12.128, 172.217.193.128, 172.217.204.128, ...\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|108.177.12.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 6260246 (6.0M) [application/octet-stream]\n",
            "Saving to: ‘data-00004-of-00010’\n",
            "\n",
            "data-00004-of-00010 100%[===================\u003e]   5.97M  --.-KB/s    in 0.03s   \n",
            "\n",
            "2021-09-22 14:58:09 (179 MB/s) - ‘data-00004-of-00010’ saved [6260246/6260246]\n",
            "\n",
            "--2021-09-22 14:58:09--  https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/random_split/test/data-00005-of-00010\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.214.128, 173.194.215.128, 173.194.216.128, ...\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.214.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 6128724 (5.8M) [application/octet-stream]\n",
            "Saving to: ‘data-00005-of-00010’\n",
            "\n",
            "data-00005-of-00010 100%[===================\u003e]   5.84M  --.-KB/s    in 0.03s   \n",
            "\n",
            "2021-09-22 14:58:09 (172 MB/s) - ‘data-00005-of-00010’ saved [6128724/6128724]\n",
            "\n",
            "--2021-09-22 14:58:09--  https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/random_split/test/data-00006-of-00010\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.204.128, 172.217.203.128, 172.253.123.128, ...\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.204.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 6290075 (6.0M) [application/octet-stream]\n",
            "Saving to: ‘data-00006-of-00010’\n",
            "\n",
            "data-00006-of-00010 100%[===================\u003e]   6.00M  --.-KB/s    in 0.04s   \n",
            "\n",
            "2021-09-22 14:58:09 (135 MB/s) - ‘data-00006-of-00010’ saved [6290075/6290075]\n",
            "\n",
            "--2021-09-22 14:58:09--  https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/random_split/test/data-00007-of-00010\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.193.128, 172.217.204.128, 172.217.203.128, ...\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.193.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 6140375 (5.9M) [application/octet-stream]\n",
            "Saving to: ‘data-00007-of-00010’\n",
            "\n",
            "data-00007-of-00010 100%[===================\u003e]   5.86M  --.-KB/s    in 0.04s   \n",
            "\n",
            "2021-09-22 14:58:10 (167 MB/s) - ‘data-00007-of-00010’ saved [6140375/6140375]\n",
            "\n",
            "--2021-09-22 14:58:10--  https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/random_split/test/data-00008-of-00010\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 172.217.204.128, 172.217.203.128, 74.125.141.128, ...\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|172.217.204.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 6132396 (5.8M) [application/octet-stream]\n",
            "Saving to: ‘data-00008-of-00010’\n",
            "\n",
            "data-00008-of-00010 100%[===================\u003e]   5.85M  --.-KB/s    in 0.04s   \n",
            "\n",
            "2021-09-22 14:58:10 (152 MB/s) - ‘data-00008-of-00010’ saved [6132396/6132396]\n",
            "\n",
            "--2021-09-22 14:58:10--  https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/random_split/test/data-00009-of-00010\n",
            "Resolving storage.googleapis.com (storage.googleapis.com)... 173.194.214.128, 173.194.216.128, 173.194.217.128, ...\n",
            "Connecting to storage.googleapis.com (storage.googleapis.com)|173.194.214.128|:443... connected.\n",
            "HTTP request sent, awaiting response... 200 OK\n",
            "Length: 6061911 (5.8M) [application/octet-stream]\n",
            "Saving to: ‘data-00009-of-00010’\n",
            "\n",
            "data-00009-of-00010 100%[===================\u003e]   5.78M  --.-KB/s    in 0.04s   \n",
            "\n",
            "2021-09-22 14:58:10 (151 MB/s) - ‘data-00009-of-00010’ saved [6061911/6061911]\n",
            "\n"
          ]
        },
        {
          "data": {
            "text/plain": [
              ""
            ]
          },
          "execution_count": 11,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "%%shell\n",
        "for i in `seq 0 9`; do\n",
        "  wget https://storage.googleapis.com/brain-genomics-public/research/proteins/pfam/random_split/test/data-0000$i-of-00010;\n",
        "done"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uC5itbJEFlYG"
      },
      "outputs": [],
      "source": [
        "import glob\n",
        "import pandas as pd\n",
        "test_dfs = []\n",
        "for f_name in glob.glob('data*'):\n",
        "  with open(f_name) as f:\n",
        "    test_dfs.append(pd.read_csv(f))\n",
        "test_df = pd.concat(test_dfs)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kQgDEIDrGE9J"
      },
      "outputs": [],
      "source": [
        ""
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "FUP4fv_iGa-e"
      },
      "outputs": [],
      "source": [
        "import math\n",
        "def infer(batch):\n",
        "  seq_lens = [len(seq) for seq in batch]\n",
        "  one_hots = [residues_to_one_hot(seq) for seq in batch]\n",
        "  padded_sequence_inputs = [pad_one_hot_sequence(seq, max(seq_lens)) for seq in one_hots]\n",
        "  with graph.as_default():\n",
        "    return sess.run(\n",
        "        top_pick_signature_tensor_name,\n",
        "        {\n",
        "            sequence_input_tensor_name: padded_sequence_inputs,\n",
        "            sequence_lengths_input_tensor_name: seq_lens,\n",
        "        })"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Npguo8s04Hvx"
      },
      "outputs": [],
      "source": [
        "# Sort test_df by sequence length so that batches have as little padding as \n",
        "# possible -\u003e faster inference.\n",
        "test_df = test_df.sort_values('sequence', key=lambda col: [len(c) for c in col])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "SYxOSJ6A53P8"
      },
      "source": [
        "# Predict domain Pfam labels for 126 thousand domains"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "NyGGDdtoGQRe",
        "outputId": "ab5cf6c4-e7c1-470c-c59f-e0c5665f6ce8"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "100%|██████████| 3943/3943 [20:20\u003c00:00,  3.23it/s]\n"
          ]
        }
      ],
      "source": [
        "inference_results = []\n",
        "batches = list(batch_iterable(test_df.sequence, 32))\n",
        "for seq_batch in tqdm.tqdm(batches, position=0):\n",
        "  inference_results.extend(infer(seq_batch))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "rvpO_Ry1IHMI"
      },
      "outputs": [],
      "source": [
        "test_df['predicted_label'] = [vocab[i] for i in inference_results]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "zqx0A1nAIkKb"
      },
      "outputs": [],
      "source": [
        "# Convert true labels from PF00001.21 to PF00001\n",
        "test_df['true_label'] = test_df.family_accession.apply(lambda s: s.split('.')[0])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "uKBt0BFY59VT"
      },
      "source": [
        "# Compute accuracy"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZFupt5VQA7V-"
      },
      "source": [
        "Reproduces 5th row of figure 1A"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "9IrXtJhHIZN3",
        "outputId": "c9180bc2-a2b3-4c72-ce65-07ab1051092a"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "family calling error rate (percentage) = 0.495\n"
          ]
        }
      ],
      "source": [
        "print('family calling error rate (percentage) = {:.03f}'.format(100-sum(test_df.true_label == test_df.predicted_label) / len(test_df) * 100))"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "collapsed_sections": [],
      "name": "Neural network accuracy on random seed split",
      "provenance": [],
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
